/**@@@+++@@@@******************************************************************
**
** Microsoft Windows Media
** Copyright (C) Microsoft Corporation. All rights reserved.
**
***@@@---@@@@******************************************************************
*/

#ifdef DX_WMDRM_USE_CRYS

#else
#include "rsaimpl.h"
/*
       File rsapimpl.c .  Version 12 October 2002

       This file has bignum interfaces
       unique to Microsoft's impl operating system.

       The main objective is to provide RSA support
       (encryption, decryption, key generation).
       The implementation uses the CRT to speed RSA decryption
       but avoids marginal optimizations which would
       significantly lengthen this open-source code.

       A separate file has impl interfaces needed
       by bignum, such as heap allocation and
       secure random number generation.
*/

/**********************************************************************
**
** Function:    big_endian_bytes_to_digits
**
** Synopsis:    Pack a big-endian DRM_BYTE string into a digit_t array.
**              darray[0] receives the trailing (least significant)
**              bytes of barray.
**
** Arguments:   [barray]  - input bytes, most significant byte first;
**                          (bitlen + 7)/8 bytes are read
**              [bitlen]  - length of the value in bits
**              [darray]  - output, BITS_TO_DIGITS(bitlen) digits
**
** Returns:     TRUE  on success; also when bitlen == 0, in which case
**                    neither array is touched
**              FALSE when bitlen != 0 and either pointer is NULL
**
**********************************************************************/
DRM_BOOL big_endian_bytes_to_digits
        (   const DRM_BYTE  *barray,    /* IN */
            DWORDREGC       bitlen,     /* IN */
            digit_t         *darray)    /* OUT */
{
    DWORDREGC ndigits = BITS_TO_DIGITS(bitlen);
    DWORDREGC nbytes  = (bitlen + 7)/8;
    DWORDREG  d;

    /* An empty value is trivially converted, even with NULL pointers. */
    if (bitlen == 0)
    {
        return TRUE;
    }

    if (barray == NULL || darray == NULL)
    {
      TRACE(("NULL == barray || NULL == darray (%s:%d)\n", __FILE__, __LINE__));
        return FALSE;
    }

    mp_clear(darray, ndigits, NULL);

    for (d = 0; d < ndigits; d++)
    {
        /* Bytes not yet consumed by lower digits; the next digit is   */
        /* built from the tail of that prefix, last byte = low 8 bits. */
        DWORDREGC bytes_left = nbytes - RADIX_BYTES*d;
        DWORDREGC take       = MIN(RADIX_BYTES, bytes_left);
        digit_t   word       = 0;
        DWORDREG  k;

        for (k = 0; k < take; k++)
        {
            word ^= (digit_t)GET_BYTE(barray, bytes_left - 1 - k) << (8*k);
        }
        darray[d] = word;
    }

    /* Discard any bits of the top digit that lie beyond bitlen. */
    darray[ndigits - 1] &= RADIXM1 >> (RADIX_BITS*ndigits - bitlen);

    return TRUE;
} /* end big_endian_bytes_to_digits */


/**********************************************************************
**
** Function:    little_endian_bytes_to_digits
**
** Synopsis:    Pack a little-endian DRM_BYTE string into a digit_t
**              array.  darray[0] receives the leading (least
**              significant) bytes of barray.
**
** Arguments:   [barray]  - input bytes, least significant byte first;
**                          (bitlen + 7)/8 bytes are read
**              [bitlen]  - length of the value in bits
**              [darray]  - output, BITS_TO_DIGITS(bitlen) digits
**
** Returns:     TRUE  on success; also when bitlen == 0, in which case
**                    neither array is touched
**              FALSE when bitlen != 0 and either pointer is NULL
**
**********************************************************************/
DRM_BOOL little_endian_bytes_to_digits
        (   const DRM_BYTE *barray,     /* IN */
            DWORDREGC       bitlen,     /* IN */
            digit_t         *darray)    /* OUT */
{
    DWORDREGC ndigits = BITS_TO_DIGITS(bitlen);
    DWORDREGC nbytes  = (bitlen + 7)/8;
    DWORDREG  d;

    /* An empty value is trivially converted, even with NULL pointers. */
    if (bitlen == 0)
    {
        return TRUE;
    }

    if (barray == NULL || darray == NULL)
    {
        return FALSE;
    }

    mp_clear(darray, ndigits, NULL);

    for (d = 0; d < ndigits; d++)
    {
        /* The top digit may be only partially filled. */
        DWORDREGC bytes_left = nbytes - RADIX_BYTES*d;
        DWORDREGC take       = MIN(RADIX_BYTES, bytes_left);
        DWORDREGC first      = d*RADIX_BYTES;
        digit_t   word       = 0;
        DWORDREG  k;

        for (k = 0; k < take; k++)
        {
            word ^= (digit_t)GET_BYTE(barray, first + k) << (8*k);
        }
        darray[d] = word;
    }

    /* Discard any bits of the top digit that lie beyond bitlen. */
    darray[ndigits - 1] &= RADIXM1 >> (RADIX_BITS*ndigits - bitlen);

    return TRUE;
} /* end little_endian_bytes_to_digits */


/**********************************************************************
**
** Function:    digits_to_big_endian_bytes
**
** Synopsis:    Unpack a digit_t array into bytes, most significant
**              byte first.
**
** Arguments:   [darray]  - input, BITS_TO_DIGITS(bitlen) digits,
**                          least significant digit first
**              [bitlen]  - length of the value in bits
**              [barray]  - output, (bitlen + 7)/8 bytes written
**
** Returns:     TRUE  on success
**              FALSE if either pointer is NULL (checked even when
**                    bitlen == 0)
**
**********************************************************************/
DRM_BOOL digits_to_big_endian_bytes
       (digit_tc    *darray,   /* IN */
        DWORDREGC    bitlen,   /* IN */
        DRM_BYTE    *barray)   /* OUT */
{
    DWORDREGC ndigits = BITS_TO_DIGITS(bitlen);
    DWORDREGC nbytes  = (bitlen + 7)/8;
    DWORDREG  d;

    if (barray == NULL || darray == NULL)
    {
        return FALSE;
    }

    for (d = 0; d < ndigits; d++)
    {
        /* Digit d fills the tail of the not-yet-written prefix of */
        /* barray; its low byte goes to the last position.         */
        DWORDREGC bytes_left = nbytes - RADIX_BYTES*d;
        DWORDREGC count      = MIN(bytes_left, RADIX_BYTES);
        digit_tc  word       = darray[d];
        DWORDREG  k;

        for (k = 0; k < count; k++)
        {
            PUT_BYTE(barray, bytes_left - 1 - k, (DRM_BYTE)((word >> (8*k)) & 0xff));
        }
    }

    return TRUE;
} /* end digits_to_big_endian_bytes */


/**********************************************************************
**
** Function:    digits_to_little_endian_bytes
**
** Synopsis:    Unpack a digit_t array into bytes, least significant
**              byte first.
**
** Arguments:   [darray]  - input, BITS_TO_DIGITS(bitlen) digits,
**                          least significant digit first
**              [bitlen]  - length of the value in bits
**              [barray]  - output, (bitlen + 7)/8 bytes written
**
** Returns:     TRUE  on success
**              FALSE if either pointer is NULL (checked even when
**                    bitlen == 0)
**
**********************************************************************/
DRM_BOOL digits_to_little_endian_bytes
       (digit_tc    *darray,    /* IN */
        DWORDREGC   bitlen,     /* IN */
        DRM_BYTE    *barray)    /* OUT */
{
    DWORDREGC ndigits = BITS_TO_DIGITS(bitlen);
    DWORDREGC nbytes  = (bitlen + 7)/8;
    DWORDREG  d;

    if (barray == NULL || darray == NULL)
    {
        return FALSE;
    }

    for (d = 0; d < ndigits; d++)
    {
        /* The top digit may contribute fewer than RADIX_BYTES bytes. */
        DWORDREGC bytes_left = nbytes - RADIX_BYTES*d;
        DWORDREGC count      = MIN(bytes_left, RADIX_BYTES);
        DWORDREGC first      = RADIX_BYTES*d;
        digit_tc  word       = darray[d];
        DWORDREG  k;

        for (k = 0; k < count; k++)
        {
            PUT_BYTE(barray, first + k, (DRM_BYTE)((word >> (8*k)) & 0xff));
        }
    }
    return TRUE;
} /* end digits_to_little_endian_bytes */

/**********************************************************************
**
** Function:    rsa_key_internalize
**
** Synopsis:    Builds the internal (digit_t based) form of an external
**              RSA key.  The external key supplies big-endian byte
**              strings for the public exponent, the modulus, the two
**              primes and the two per-prime private exponents, so it
**              must be a private key.
**
** Arguments:   [prsaext]  - input pointer to an existing external key
**              [prsaint]  - output pointer to corresponding internal 
**                           key.  On success it owns one heap block
**                           (prsaint->free_me) and two created moduli
**                           (prsaint->moduli[0..1]); nothing in this
**                           function releases them on the success path.
**
**
** Returns:     TRUE on success
**              FALSE on error (NULL argument, allocation failure,
**              conversion/create_modulus/mp_gcdex failure, or
**              GCD(p1, p2) != 1).  On failure the moduli created so
**              far are uncreated and the heap block is freed.
**              NOTE(review): after such a failure prsaint->free_me,
**              modulus, pubexp, privexps[] and chineses[] may still
**              point into the freed block; callers must not use them.
**
**********************************************************************/
DRM_BOOL rsa_key_internalize(
            const external_rsa_key_t  *prsaext,    /* IN */
            internal_rsa_key_t        *prsaint)    /* OUT */
{
    DRM_BOOL OK = TRUE;

    if (   NULL == prsaext 
        || NULL == prsaint)
    {
        OK = FALSE;
        SetMpErrno_clue(MP_ERRNO_NULL_POINTER, "rsa_key_internalize", NULL);
        /* Should check here that prsaext is really an external private key. */
    } 
    else
    {
        DWORDREGC diglen_pubexp = BITS_TO_DIGITS(prsaext->bitlen_pubexp);
        DWORDREGC diglen_p1 = BITS_TO_DIGITS(prsaext->bitlen_primes[0]);
        DWORDREGC diglen_p2 = BITS_TO_DIGITS(prsaext->bitlen_primes[1]);
        DWORDREGC diglen_p12 = diglen_p1 + diglen_p2;

        /* TBD Perhaps create_modulus should let the application supply its temporaries. */
        /* Then there will be a single heap allocation per internal_rsa_key_t, */
        /* not three smaller ones. */

        /* Layout of the single block (offsets in digits):               */
        /*   [0, p12)                 modulus (also used as scratch)     */
        /*   [p12, p12 + p1)          privexps[0]                        */
        /*   [p12 + p1, 2*p12)        privexps[1]                        */
        /*   [2*p12, 2*p12 + p1)      chineses[0]                        */
        /*   [2*p12 + p1, 3*p12)      chineses[1]                        */
        /*   [3*p12, 3*p12 + pubexp)  pubexp                             */
        digit_t *dtemps = digit_allocate(3*diglen_p12 + diglen_pubexp,
                                         "rsa_key_internalize", NULL);
        if (dtemps == digit_NULL)
        {
            OK = FALSE;
        }

        if (OK)
        {
            digit_t *temp1 = dtemps;    /* Overlaps modulus */
            DWORDREG ip, moduli_created = 0;

            prsaint->bitlen_modulus = prsaext->bitlen_modulus;
            prsaint->diglen_pubexp = (DRM_DWORD)diglen_pubexp;
            prsaint->free_me = dtemps;
            prsaint->modulus = dtemps;
            prsaint->pubexp = dtemps + 3*diglen_p12;
            OK = OK && big_endian_bytes_to_digits(prsaext->pubexp,
                                    prsaext->bitlen_pubexp, prsaint->pubexp);
            for (ip = 0; ip != 2; ip++)
            {
                /* NOTE(review): shadows the outer temp1; both hold the */
                /* same address (start of the block / modulus slot).    */
                digit_t *temp1 = prsaint->modulus;
                DWORDC bitlen_p = prsaext->bitlen_primes[ip];

                prsaint->privexps[ip] = dtemps + diglen_p12 + ip*diglen_p1;
                prsaint->chineses[ip] = prsaint->privexps[ip] + diglen_p12;

                /* The prime is converted into the modulus slot as scratch */
                /* and handed to create_modulus; the slot is overwritten   */
                /* again on the next pass and by the code below, so        */
                /* create_modulus presumably keeps its own copy - confirm. */
                OK = OK && big_endian_bytes_to_digits(prsaext->primes[ip],
                                            bitlen_p, temp1);
                OK = OK && create_modulus(temp1,
                               BITS_TO_DIGITS(prsaext->bitlen_primes[ip]),
                               FROM_RIGHT, &prsaint->moduli[ip], NULL, NULL);
                if (OK)
                {
                    /* Counted so the failure path knows how many to uncreate. */
                    moduli_created++;
                }

                OK = OK && big_endian_bytes_to_digits(prsaext->privexps[ip],
                                            bitlen_p, prsaint->privexps[ip]);
            } /* for ip */

            if (OK)
            {
                DWORDREG lgcd = 0;
                /* mp_gcdex writes the GCD of the two primes (lgcd digits)   */
                /* to temp1 and fills chineses[1] / chineses[0]; presumably  */
                /* those are the cross inverses used by the CRT in           */
                /* rsa_decryption - confirm against mp_gcdex's contract.     */
                OK = OK && mp_gcdex(prsaint->moduli[0].modulus, diglen_p1,
                                    prsaint->moduli[1].modulus, diglen_p2,
                                    prsaint->chineses[1], prsaint->chineses[0],
                                    temp1, digit_NULL, &lgcd, digit_NULL, NULL);
                if (OK && compare_immediate(temp1, 1, lgcd) != 0)
                {
                    /* The two primes must be coprime. */
                    OK = FALSE;
                    SetMpErrno_clue(MP_ERRNO_INVALID_DATA,
                        "rsa_key_internalize, GCD(p1, p2) <> 1", NULL);
                }

                /* Scratch use of the modulus slot is over: zero all p12   */
                /* digits, then load the real modulus.                     */
                mp_clear(prsaint->modulus, diglen_p12, NULL);

                /* Possibly insert leading zero */
                OK = OK && big_endian_bytes_to_digits(prsaext->modulus,
                                                      prsaext->bitlen_modulus,
                                                      prsaint->modulus);
            }
            /* Failure cleanup: undo create_modulus calls in reverse order. */
            while (!OK && moduli_created != 0)
            {
                moduli_created--;
                uncreate_modulus(&prsaint->moduli[moduli_created], NULL);
            }
        } /* if */

        /* On failure the block is released here; on success ownership */
        /* stays with prsaint->free_me.                                 */
        if (!OK && dtemps != digit_NULL)
        {
            Free_Temporaries(dtemps, NULL);
        }
    } /* if */
    return OK;
} /* rsa_key_internalize */


/**********************************************************************
**
** Function:    RSA_Encrypt
**
** Synopsis:    Encrypts a buffer of plain text using an external key.
**              Thin wrapper that forwards to rsa_encryption.
**
** Arguments:   [pRSAkey]            - pointer to external (public) key
**              [pbPlainText]        - pointer to buffer holding the plaintext
**                                     plaintext must be smaller than modulus
**              [nPlainTextLength]   - length of plaintext; when zero
**                                     nothing is encrypted
**              [pbCipherText]       - output buffer containing ciphertext
**              [pcCipherTextLength] - length of output buffer; must be
**                                     non-NULL
**
** Returns:     none.  The function is a silent no-op if any pointer is
**              NULL or nPlainTextLength is zero.
**
** NOTE(review): *pcCipherTextLength is neither read nor updated here,
**              and the result of rsa_encryption is discarded, so a
**              caller cannot tell whether pbCipherText was written.
**
**********************************************************************/
DRM_VOID 
RSA_Encrypt(
    IN external_rsa_key_t* pRSAkey,
    IN DRM_BYTE* pbPlainText,
    IN DRM_DWORD nPlainTextLength,
    OUT DRM_BYTE* pbCipherText,
    IN OUT DRM_DWORD* pcCipherTextLength
    )
{
    if (   pRSAkey == NULL
        || pbPlainText == NULL
        || pbCipherText == NULL
        || pcCipherTextLength == NULL)
    {
        return;
    }

    /* pcCipherTextLength is known to be non-NULL here, so the plaintext */
    /* length is the only remaining gate.                                */
    if (nPlainTextLength == 0)
    {
        return;
    }

    (void)rsa_encryption(pRSAkey, pbPlainText, pbCipherText);
}


/**********************************************************************
**
** Function:    rsa_encryption
**
** Synopsis:    Encrypt msgin -> msgout, using the RSA public key in prsaext.
**              This function is the worker behind RSA_Encrypt.
**
** Arguments:   [prsaext]   - pointer to external rsa key; only its
**                            modulus and public exponent are used
**              [msgin]     - input buffer, big-endian,
**                            prsaext->bitlen_modulus bits are read
**              [msgout]    - output buffer, big-endian,
**                            prsaext->bitlen_modulus bits are written
**
** Returns:     TRUE on success
**              FALSE on error (NULL argument, allocation failure, or
**              failure of any bignum step)
**
** Notes:       When a step fails, only that step is traced; the
**              steps that were skipped as a consequence are not
**              reported as failures.
**
**********************************************************************/
DRM_BOOL rsa_encryption
       (const external_rsa_key_t    *prsaext,   /* IN */
        const DRM_BYTE              *msgin,     /* IN */
                                                /*     prsaext->bitlen_modulus bits */
        DRM_BYTE                    *msgout)    /* OUT */
{
    DRM_BOOL OK = TRUE;
    DRM_BOOL modulus_created = FALSE;
    digit_t *dtemps = digit_NULL;
    digit_t *dmsg = digit_NULL;         /* Length diglen_mod */
    digit_t *dexponent = digit_NULL;    /* Length diglen_pubexp */
    DWORDREG bitlen_pubexp = 0, bitlen_mod = 0;
    DWORDREG diglen_pubexp = 0, diglen_mod = 0;
    mp_modulus_t modulo;

    if (   NULL == prsaext
        || NULL == msgin
        || NULL == msgout)
    {
        SetMpErrno_clue(MP_ERRNO_NULL_POINTER, "rsa_encryption", NULL);
        return FALSE;
    }

    /* Do other validation checks here */

    bitlen_pubexp = prsaext->bitlen_pubexp;
    bitlen_mod    = prsaext->bitlen_modulus;
    diglen_pubexp = BITS_TO_DIGITS(bitlen_pubexp);
    diglen_mod    = BITS_TO_DIGITS(bitlen_mod);

    /* One block holds the message/result followed by the exponent. */
    dtemps = digit_allocate(diglen_mod + diglen_pubexp, "rsa_encryption", NULL);
    if (dtemps == digit_NULL)
    {
        return FALSE;
    }
    dmsg      = dtemps;
    dexponent = dtemps + diglen_mod;

    /* dmsg first carries the modulus into create_modulus, and is then */
    /* reused for the message itself.                                  */
    if (!big_endian_bytes_to_digits(prsaext->modulus, bitlen_mod, dmsg))
    {
        OK = FALSE;
        TRACE(("big_endian_bytes_to_digits 1 failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }

    if (OK && !create_modulus(dmsg, diglen_mod, FROM_RIGHT, &modulo, NULL, NULL))
    {
        OK = FALSE;
        TRACE(("create_modulus failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }
    /* Remember whether modulo must be uncreated, whatever happens later. */
    modulus_created = OK;

    if (OK && !big_endian_bytes_to_digits(prsaext->pubexp, bitlen_pubexp, dexponent))
    {
        OK = FALSE;
        TRACE(("big_endian_bytes_to_digits 2 failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }

    if (OK && !big_endian_bytes_to_digits(msgin, bitlen_mod, dmsg))
    {
        OK = FALSE;
        TRACE(("big_endian_bytes_to_digits 3 failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }

    /* msg^pubexp mod modulus, computed in place in dmsg. */
    if (OK && !to_modular(dmsg, diglen_mod, dmsg, &modulo, NULL))
    {
        OK = FALSE;
        TRACE(("to_modular failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }

    if (OK && !mod_exp(dmsg, dexponent, diglen_pubexp, dmsg, &modulo, NULL))
    {
        OK = FALSE;
        TRACE(("mod_exp failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }

    if (OK && !from_modular(dmsg, dmsg, &modulo, NULL))
    {
        OK = FALSE;
        TRACE(("from_modular failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }

    if (OK && !digits_to_big_endian_bytes(dmsg, bitlen_mod, msgout))
    {
        OK = FALSE;
        TRACE(("digits_to_big_endian_bytes failed (mp_errno = %d)\n", GetMpErrno(NULL)));
    }

    if (modulus_created)
    {
        uncreate_modulus(&modulo, NULL);
    }

    Free_Temporaries(dtemps, NULL);

    return OK;
} /* rsa_encryption */


/**********************************************************************
**
** Function:    RSA_Decrypt
**
** Synopsis:    Decrypts a ciphertext buffer.  The external key is
**              converted to a temporary internal key with
**              rsa_key_internalize, used by rsa_decryption, and then
**              released again before returning.
**
** Arguments:   [pRSAkey]               - pointer to the external (private) key
**              [pbCipherText]          - pointer to the input buffer with the ciphertext
**              [nCipherTextLength]     - length of the ciphertext; when
**                                        zero nothing is decrypted
**              [pbDecryptedData]       - pointer to the output plaintext
**              [pcDecryptedDataLength] - length of the plaintext; must
**                                        be non-NULL
**
** Returns:     none.  The function is a silent no-op if any pointer is
**              NULL, if the key cannot be internalized, or if
**              nCipherTextLength is zero.
**
** NOTE(review): *pcDecryptedDataLength is neither read nor updated
**              here, and the result of rsa_decryption is discarded.
**
**********************************************************************/
DRM_VOID 
RSA_Decrypt(
    IN external_rsa_key_t* pRSAkey,
    IN DRM_BYTE* pbCipherText,
    IN DRM_DWORD nCipherTextLength,
    OUT DRM_BYTE* pbDecryptedData,
    IN OUT DRM_DWORD* pcDecryptedDataLength
    )
{
    internal_rsa_key_t rsaint;
    DWORDREG ip;

    if (   NULL == pRSAkey
        || NULL == pbCipherText
        || NULL == pbDecryptedData
        || NULL == pcDecryptedDataLength)
    {
        return;
    }

    /* On failure rsa_key_internalize has already uncreated its moduli */
    /* and freed its heap block, so there is nothing to release here.  */
    if (!rsa_key_internalize(pRSAkey, &rsaint))
    {
        return;
    }

    /* pcDecryptedDataLength is known to be non-NULL at this point. */
    if (nCipherTextLength != 0)
    {
        (void)rsa_decryption(&rsaint, pbCipherText, pbDecryptedData);
    }

    /* Release everything the temporary internal key owns: the two     */
    /* created moduli and the single heap block behind free_me.        */
    /* Without this every call leaked the block and both moduli.       */
    for (ip = 0; ip != 2; ip++)
    {
        uncreate_modulus(&rsaint.moduli[ip], NULL);
    }
    Free_Temporaries(rsaint.free_me, NULL);
}


/**********************************************************************
**
** Function:    rsa_decryption
**
** Synopsis:    Decrypt msgin -> msgout, using the RSA private key in prsaint.
**              This function is the worker behind RSA_Decrypt.
**
** Arguments:   [prsaint]   - pointer to internal key (read only)
**              [msgin]     - input buffer with the ciphertext,
**                            big-endian, prsaint->bitlen_modulus bits
**              [msgout]    - output buffer with the plaintext,
**                            big-endian, prsaint->bitlen_modulus bits
**
**
** Returns:     TRUE on success
**              FALSE on error: NULL argument (MP_ERRNO_NULL_POINTER),
**              inconsistent or too-large key for the fixed scratch
**              buffer (MP_ERRNO_INVALID_DATA), or failure of a
**              bignum step
**
**********************************************************************/
DRM_BOOL rsa_decryption
       (const internal_rsa_key_t    *prsaint,   /* IN */
        const DRM_BYTE              *msgin,     /* IN, length */
                                                /*    prsaint->bitlen_modulus bits */
        DRM_BYTE                    *msgout)    /* OUT */
{
/*
     This code decrypts msgin, using the RSA private key in prsaint.
     The output is stored in msgout.
     msgin and msgout are big-endian DRM_BYTE arrays.

     Using the Chinese Remainder Theorem, we reduce the original message
     (copied into dmsg[1]) modulo the two private primes p1 and p2, getting
     remainders in1 and in2 (say).  Form out1 = in1^(private exponent 1) mod p1
     and out2 == in2^(private exponent 2) mod p2.
     The desired output will have remainder out1 modulo p1 and out2 modulo p2.
     Such a result is p2*(out1/p2 mod p1) + p1*(out2/p1 mod p2).
     The final sum (of dmsg[0] and dmsg[1]) is done modulo p1*p2 = modulus.
*/
    DRM_BOOL OK = TRUE;

    if (   NULL == prsaint
        || NULL == msgin
        || NULL == msgout)
    {
        OK = FALSE;
        SetMpErrno_clue(MP_ERRNO_NULL_POINTER, "rsa_decryption", NULL);

    /* Do other validation checks here, */
    /* such as ensuring we have private (not public) key */
    }
    else
    {
        /* Fixed scratch area, carved up below as                     */
        /*   dmsg[0] : diglen_modulus + 1 digits                      */
        /*   dmsg[1] : diglen_modulus + 1 digits                      */
        /*   res     : diglen_plonger digits (longer of the 2 primes) */
        digit_t dtemps[82];
        DWORDREGC dtemps_len = (DWORDREG)(sizeof(dtemps)/sizeof(dtemps[0]));
        mp_modulus_tc *moduli = prsaint->moduli;    /* Shorthand */
        DWORDREGC diglen_modulus = BITS_TO_DIGITS(prsaint->bitlen_modulus);
        DWORDREGC diglen_plonger = (DWORDREG)(moduli[0].length > moduli[1].length
                                                  ? moduli[0].length
                                                  : moduli[1].length);

        if (moduli[0].length + moduli[1].length - diglen_modulus > 1)
        {
            /* N.B. Unsigned compare */
            /* The prime lengths must add up to diglen_modulus or      */
            /* diglen_modulus + 1, so each product below fits a dmsg   */
            /* slot.                                                   */
            OK = FALSE;
            SetMpErrno_clue(MP_ERRNO_INVALID_DATA, "rsa_decryption", NULL);
        }
        else if (   diglen_modulus >= dtemps_len/2
                 || diglen_plonger > dtemps_len - 2*(diglen_modulus + 1))
        {
            /* The key needs 2*(diglen_modulus + 1) + diglen_plonger   */
            /* digits of scratch; refuse keys that would overrun the   */
            /* fixed-size stack buffer.  The first test also keeps the */
            /* subtraction in the second from wrapping.                */
            OK = FALSE;
            SetMpErrno_clue(MP_ERRNO_INVALID_DATA, "rsa_decryption", NULL);
        }

        if (OK)
        {
            DWORDREG ip;
            digit_t *dmsg[2];
            digit_t *res = &dtemps[2*(diglen_modulus + 1)];
            /* Length diglen_plonger; */
            dmsg[0] = &dtemps[0];
            dmsg[1] = &dtemps[diglen_modulus + 1];

            OK = OK && big_endian_bytes_to_digits(msgin,
                                    prsaint->bitlen_modulus, dmsg[1]);
            /* ip == 0 must run first: it still reads the ciphertext   */
            /* from dmsg[1] while writing its product to dmsg[0]; the  */
            /* ip == 1 pass then overwrites dmsg[1] with its product.  */
            for (ip = 0; OK && ip != 2; ip++)
            {
                OK = OK && to_modular(dmsg[1], diglen_modulus, res, &moduli[ip], NULL)
                        && mod_exp(res, prsaint->privexps[ip],
                                   moduli[ip].length, res, &moduli[ip], NULL)
                        && mod_mul(res, prsaint->chineses[ip], res,
                                   &moduli[ip], digit_NULL, NULL)

                        /* N.B.  from_modular call omitted since chineses[ip] lacks FROM_RIGHT scaling, */
                        && multiply(res, moduli[ip].length,
                                    moduli[1-ip].modulus,
                                    moduli[1-ip].length, dmsg[ip], NULL);
            } /* for ip */
            OK = OK && add_mod(dmsg[0], dmsg[1], dmsg[0],
                               prsaint->modulus, diglen_modulus, NULL)
                    && digits_to_big_endian_bytes(dmsg[0],
                                          prsaint->bitlen_modulus, msgout);
        } /* if */
    }

    return OK;
} /* end rsa_decryption */

#endif //#ifndef DX_WMDRM_USE_CRYS
